{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnx import numpy_helper\n",
    "\n",
    "def f(t):\n",
    "    return [f(i) for i in t] if isinstance(t, (list, tuple)) else t\n",
    "\n",
    "def g(t, res):\n",
    "    for i in t:\n",
    "        res.append(i) if not isinstance(i, (list, tuple)) else g(i, res)\n",
    "    return res\n",
    "\n",
    "def SaveData(test_data_dir, prefix, data_list):\n",
    "    if isinstance(data_list, torch.autograd.Variable) or isinstance(data_list, torch.Tensor):\n",
    "        data_list = [data_list]\n",
    "    for i, d in enumerate(data_list):\n",
    "        d = d.data.cpu().numpy()\n",
    "        SaveTensorProto(os.path.join(test_data_dir, '{0}_{1}.pb'.format(prefix, i)), prefix + str(i+1), d)\n",
    "        \n",
    "def SaveTensorProto(file_path, name, data):\n",
    "    tp = numpy_helper.from_array(data)\n",
    "    tp.name = name\n",
    "\n",
    "    with open(file_path, 'wb') as f:\n",
    "        f.write(tp.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import re\n",
    "import os\n",
    "import onnxruntime as rt\n",
    "from transformer_net import TransformerNet\n",
    "\n",
    "input = torch.randn(1, 3, 224, 224)\n",
    "with torch.no_grad():\n",
    "    model = TransformerNet()\n",
    "    model_dict = torch.load(\"PATH TO PYTORCH MODEL\")\n",
    "    for k in list(model_dict.keys()):\n",
    "        if re.search(r'in\\d+\\.running_(mean|var)$', k):\n",
    "            del model_dict[k]\n",
    "    model.load_state_dict(model_dict)\n",
    "    output = model(input)\n",
    "    \n",
    "input_names = ['input1']\n",
    "output_names = ['output1']\n",
    "dir = \"PATH TO CONVERTED ONNX MODEL\"\n",
    "if not os.path.exists(dir):\n",
    "    os.makedirs(dir)\n",
    "data_dir = os.path.join(dir, \"data_set\")\n",
    "if not os.path.exists(data_dir):\n",
    "    os.makedirs(data_dir)\n",
    "\n",
    "if isinstance(model, torch.jit.ScriptModule):\n",
    "    torch.onnx._export(model, tuple((input,)), os.path.join(dir, 'model.onnx'), verbose=True, input_names=input_names, output_names=output_names, example_outputs=(output,))\n",
    "else:\n",
    "    torch.onnx.export(model, tuple((input,)), os.path.join(dir, 'model.onnx'), verbose=True, input_names=input_names, output_names=output_names)\n",
    "\n",
    "input = f(input)\n",
    "input = g(input, [])\n",
    "output = f(output)\n",
    "output = g(output, [])\n",
    "        \n",
    "SaveData(data_dir, 'input', input)\n",
    "SaveData(data_dir, 'output', output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
